{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T14:27:48.616499Z",
     "start_time": "2019-12-09T14:27:47.812156Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T14:27:48.622334Z",
     "start_time": "2019-12-09T14:27:48.619304Z"
    }
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 512\n",
    "EPOCHS = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T14:27:48.730140Z",
     "start_time": "2019-12-09T14:27:48.624639Z"
    }
   },
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('/home/shenchenkai/data/mnist_pytorch', train=True, download=False,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ])),\n",
    "        batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('/home/shenchenkai/data/mnist_pytorch', train=False, download=False,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ])),\n",
    "        batch_size=BATCH_SIZE, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T14:28:14.542050Z",
     "start_time": "2019-12-09T14:28:14.485682Z"
    }
   },
   "outputs": [],
   "source": [
    "d = datasets.MNIST('/home/shenchenkai/data/mnist_pytorch', train=True, download=False,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T14:28:37.302150Z",
     "start_time": "2019-12-09T14:28:37.285791Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 28, 28])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d[0][0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T12:42:44.440625Z",
     "start_time": "2019-12-09T12:42:44.424199Z"
    }
   },
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CNN, self).__init__()\n",
    "        self.conv1 = nn.Sequential(   # (1,28,28)   \n",
    "            # in_channels:输入的通道数， out_channels：卷积核数量， kernel_size：卷积核大小， stride：步长\n",
    "            # stride=1时， padding=(kernel_size-1)/2， 图片长宽不变\n",
    "            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),    # (16,28,28)\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),    # (16,14,14)\n",
    "        )\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),    # (32,14,14)\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),    # （32,7,7）\n",
    "        )\n",
    "        self.fc1 = nn.Sequential(\n",
    "            nn.Linear(32*7*7, 500),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.conv2(x)    # (batch,32,7,7)\n",
    "        x = x.view(x.size(0), -1)    # (batch,32*7*7)\n",
    "        x = self.fc1(x)\n",
    "        output = self.fc2(x)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T12:42:45.155182Z",
     "start_time": "2019-12-09T12:42:45.131540Z"
    }
   },
   "outputs": [],
   "source": [
    "model = CNN()\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "loss_fn = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T12:42:46.095286Z",
     "start_time": "2019-12-09T12:42:46.088809Z"
    }
   },
   "outputs": [],
   "source": [
    "def train(model, train_loader, optimizer, loss_fn, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = loss_fn(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if (batch_idx+1) % 30 == 0:\n",
    "            print('Train Epoch{} [{}/{}], loss:{:.4f}'.format(\n",
    "                    epoch, batch_idx*len(data), len(train_loader.dataset), loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T12:42:47.067173Z",
     "start_time": "2019-12-09T12:42:47.060890Z"
    }
   },
   "outputs": [],
   "source": [
    "def test(model, test_loader, loss_fn):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            output = model(data)\n",
    "            test_loss += loss_fn(output, target).item() # 将一批的损失相加\n",
    "            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-09T12:48:35.081418Z",
     "start_time": "2019-12-09T12:42:48.082740Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch1 [14848/60000], loss:0.2829\n",
      "Train Epoch1 [30208/60000], loss:0.1324\n",
      "Train Epoch1 [45568/60000], loss:0.0978\n",
      "\n",
      "Test set: Average loss: 0.0002, Accuracy: 9751/10000 (98%)\n",
      "\n",
      "Train Epoch2 [14848/60000], loss:0.0579\n",
      "Train Epoch2 [30208/60000], loss:0.0520\n",
      "Train Epoch2 [45568/60000], loss:0.0554\n",
      "\n",
      "Test set: Average loss: 0.0001, Accuracy: 9816/10000 (98%)\n",
      "\n",
      "Train Epoch3 [14848/60000], loss:0.0624\n",
      "Train Epoch3 [30208/60000], loss:0.0566\n",
      "Train Epoch3 [45568/60000], loss:0.0830\n",
      "\n",
      "Test set: Average loss: 0.0001, Accuracy: 9868/10000 (99%)\n",
      "\n",
      "Train Epoch4 [14848/60000], loss:0.0685\n",
      "Train Epoch4 [30208/60000], loss:0.0258\n",
      "Train Epoch4 [45568/60000], loss:0.0278\n",
      "\n",
      "Test set: Average loss: 0.0001, Accuracy: 9879/10000 (99%)\n",
      "\n",
      "Train Epoch5 [14848/60000], loss:0.0444\n",
      "Train Epoch5 [30208/60000], loss:0.0245\n",
      "Train Epoch5 [45568/60000], loss:0.0177\n",
      "\n",
      "Test set: Average loss: 0.0001, Accuracy: 9904/10000 (99%)\n",
      "\n",
      "Train Epoch6 [14848/60000], loss:0.0126\n",
      "Train Epoch6 [30208/60000], loss:0.0433\n",
      "Train Epoch6 [45568/60000], loss:0.0102\n",
      "\n",
      "Test set: Average loss: 0.0001, Accuracy: 9906/10000 (99%)\n",
      "\n",
      "Train Epoch7 [14848/60000], loss:0.0118\n",
      "Train Epoch7 [30208/60000], loss:0.0297\n",
      "Train Epoch7 [45568/60000], loss:0.0334\n",
      "\n",
      "Test set: Average loss: 0.0001, Accuracy: 9913/10000 (99%)\n",
      "\n",
      "Train Epoch8 [14848/60000], loss:0.0148\n",
      "Train Epoch8 [30208/60000], loss:0.0073\n",
      "Train Epoch8 [45568/60000], loss:0.0479\n",
      "\n",
      "Test set: Average loss: 0.0001, Accuracy: 9913/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, EPOCHS + 1):\n",
    "    train(model, train_loader, optimizer, loss_fn, epoch)\n",
    "    test(model, test_loader, loss_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-22T15:48:57.021262Z",
     "start_time": "2019-11-22T15:48:57.000523Z"
    }
   },
   "outputs": [],
   "source": [
    "# 保存整个模型\n",
    "torch.save(model, 'mnist_model.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-24T13:45:09.897358Z",
     "start_time": "2019-11-24T13:45:09.853628Z"
    }
   },
   "outputs": [],
   "source": [
    "# 保存模型参数\n",
    "torch.save(model.state_dict(), 'mnist_model_dict.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
